# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

# ruff: noqa

import os
import sys
import tempfile
import subprocess
from collections.abc import Callable
from typing import Literal, Any, ClassVar
from pathlib import Path
from colorama import Fore

from utu.config import ToolkitConfig
from utu.tools import AsyncBaseToolkit
from utu.utils import get_logger

logger = get_logger(__name__)


class InterpreterError(Exception):
    r"""Exception raised for errors that can be solved by regenerating code"""

    pass


class SubprocessInterpreter:
    r"""SubprocessInterpreter is a class for executing code files or code
    strings in a subprocess.

    This class handles the execution of code in different scripting languages
    (currently Python and Bash) within a subprocess, capturing their
    stdout and stderr streams, and allowing user checking before executing code
    strings.

    Args:
        require_confirm (bool, optional): If True, prompt user before running
            code strings for security. (default: :obj:`True`)
        print_stdout (bool, optional): If True, print the standard output of
            the executed code. (default: :obj:`False`)
        print_stderr (bool, optional): If True, print the standard error of the
            executed code. (default: :obj:`True`)
        execution_timeout (int, optional): Maximum time in seconds to wait for
            code execution to complete. (default: :obj:`60`)
        working_dir (str, optional): The working directory for code execution.
            If None, uses current directory. (default: :obj:`None`)
    """

    _CODE_EXECUTE_CMD_MAPPING: ClassVar[dict[str, dict[str, str]]] = {
        "python": {"posix": "python {file_name}", "nt": "python {file_name}"},
        "bash": {"posix": "bash {file_name}", "nt": "bash {file_name}"},
        "r": {"posix": "Rscript {file_name}", "nt": "Rscript {file_name}"},
    }

    _CODE_EXTENSION_MAPPING: ClassVar[dict[str, str]] = {
        "python": "py",
        "bash": "sh",
        "r": "R",
    }

    _CODE_TYPE_MAPPING: ClassVar[dict[str, str]] = {
        "python": "python",
        "py3": "python",
        "python3": "python",
        "py": "python",
        "shell": "bash",
        "bash": "bash",
        "sh": "bash",
        "r": "r",
        "R": "r",
    }

    def __init__(
        self,
        require_confirm: bool = True,
        print_stdout: bool = False,
        print_stderr: bool = True,
        execution_timeout: int = 60,
        working_dir: str = None,
    ) -> None:
        self.require_confirm = require_confirm
        self.print_stdout = print_stdout
        self.print_stderr = print_stderr
        self.execution_timeout = execution_timeout
        self.working_dir = working_dir

    def run_file(
        self,
        file: Path,
        code_type: str,
    ) -> str:
        r"""Executes a code file in a subprocess and captures its output.

        Args:
            file (Path): The path object of the file to run.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.
        """
        if not file.is_file():
            return f"{file} is not a file."
        code_type = self._check_code_type(code_type)
        if self._CODE_TYPE_MAPPING[code_type] == "python":
            # For Python code, use ast to analyze and modify the code
            import ast

            import astor

            with open(file, encoding="utf-8") as f:
                source = f.read()

            # Parse the source code
            try:
                tree = ast.parse(source)
                # Get the last node
                if tree.body:
                    last_node = tree.body[-1]
                    # Handle expressions that would normally not produce output
                    # For example: In a REPL, typing '1 + 2' should show '3'

                    if isinstance(last_node, ast.Expr):
                        # Only wrap in print(repr()) if it's not already a
                        # print call
                        if not (
                            isinstance(last_node.value, ast.Call)
                            and isinstance(last_node.value.func, ast.Name)
                            and last_node.value.func.id == "print"
                        ):
                            # Transform the AST to wrap the expression in print
                            # (repr())
                            # Example transformation:
                            #   Before: x + y
                            #   After:  print(repr(x + y))
                            tree.body[-1] = ast.Expr(
                                value=ast.Call(
                                    # Create print() function call
                                    func=ast.Name(id="print", ctx=ast.Load()),
                                    args=[
                                        ast.Call(
                                            # Create repr() function call
                                            func=ast.Name(id="repr", ctx=ast.Load()),
                                            # Pass the original expression as
                                            # argument to repr()
                                            args=[last_node.value],
                                            keywords=[],
                                        )
                                    ],
                                    keywords=[],
                                )
                            )
                    # Fix missing source locations
                    ast.fix_missing_locations(tree)
                    # Convert back to source
                    modified_source = astor.to_source(tree)
                    # Create a temporary file with the modified source
                    temp_file = self._create_temp_file(modified_source, "py")
                    cmd = ["python", str(temp_file)]
            except (SyntaxError, TypeError, ValueError) as e:
                logger.warning(f"Failed to parse Python code with AST: {e}")
                platform_type = "posix" if os.name != "nt" else "nt"
                cmd_template = self._CODE_EXECUTE_CMD_MAPPING[code_type][platform_type]
                base_cmd = cmd_template.split()[0]

                # Check if command is available
                if not self._is_command_available(base_cmd):
                    raise InterpreterError(
                        f"Command '{base_cmd}' not found. Please ensure it is installed and available in your PATH."
                    )

                cmd = [base_cmd, str(file)]
        else:
            # For non-Python code, use standard execution
            platform_type = "posix" if os.name != "nt" else "nt"
            cmd_template = self._CODE_EXECUTE_CMD_MAPPING[code_type][platform_type]
            base_cmd = cmd_template.split()[0]  # Get 'python', 'bash', etc.

            # Check if command is available
            if not self._is_command_available(base_cmd):
                raise InterpreterError(
                    f"Command '{base_cmd}' not found. Please ensure it is installed and available in your PATH."
                )

            cmd = [base_cmd, str(file)]

        # Get current Python executable's environment
        env = os.environ.copy()

        # On Windows, ensure we use the correct Python executable path
        if os.name == "nt":
            python_path = os.path.dirname(sys.executable)
            if "PATH" in env:
                env["PATH"] = python_path + os.pathsep + env["PATH"]
            else:
                env["PATH"] = python_path

        try:
            proc = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                env=env,
                cwd=self.working_dir,  # Set working directory
                shell=False,  # Never use shell=True for security
            )
            # Add timeout to prevent hanging processes
            stdout, stderr = proc.communicate(timeout=self.execution_timeout)
            return_code = proc.returncode
        except subprocess.TimeoutExpired:
            proc.kill()
            stdout, stderr = proc.communicate()
            return_code = proc.returncode
            timeout_msg = f"Process timed out after {self.execution_timeout} seconds and was terminated."
            stderr = f"{stderr}\n{timeout_msg}"

        # Clean up temporary file if it was created
        temp_file_to_clean = locals().get("temp_file")
        if temp_file_to_clean is not None:
            try:
                if temp_file_to_clean.exists():
                    try:
                        temp_file_to_clean.unlink()
                    except PermissionError:
                        # On Windows, files might be locked
                        logger.warning(f"Could not delete temp file {temp_file_to_clean} (may be locked)")
            except Exception as e:
                logger.warning(f"Failed to cleanup temporary file: {e}")

        if self.print_stdout and stdout:
            print("======stdout======")
            print(Fore.GREEN + stdout + Fore.RESET)
            print("==================")
        if self.print_stderr and stderr:
            print("======stderr======")
            print(Fore.RED + stderr + Fore.RESET)
            print("==================")

        # Build the execution result
        exec_result = ""
        if stdout:
            exec_result += stdout
        if stderr:
            exec_result += f"(stderr: {stderr})"
        if return_code != 0:
            error_msg = f"(Execution failed with return code {return_code})"
            if not stderr:
                exec_result += error_msg
            elif error_msg not in stderr:
                exec_result += error_msg
        return exec_result

    def run(
        self,
        code: str,
        code_type: str,
    ) -> str:
        r"""Generates a temporary file with the given code, executes it, and
            deletes the file afterward.

        Args:
            code (str): The code string to execute.
            code_type (str): The type of code to execute (e.g., 'python',
                'bash').

        Returns:
            str: A string containing the captured stdout and stderr of the
                executed code.

        Raises:
            InterpreterError: If the user declines to run the code or if the
                code type is unsupported.
        """
        code_type = self._check_code_type(code_type)

        # Print code for security checking
        if self.require_confirm:
            logger.info(f"The following {code_type} code will run on your computer: {code}")
            while True:
                choice = input("Running code? [Y/n]:").lower().strip()
                if choice in ["y", "yes", "ye", ""]:
                    break
                elif choice in ["no", "n"]:
                    raise InterpreterError(
                        "Execution halted: User opted not to run the code. "
                        "This choice stops the current operation and any "
                        "further code execution."
                    )
                else:
                    print("Please enter 'y' or 'n'.")

        temp_file_path = None
        temp_dir = None
        try:
            temp_file_path = self._create_temp_file(code=code, extension=self._CODE_EXTENSION_MAPPING[code_type])
            temp_dir = temp_file_path.parent
            return self.run_file(temp_file_path, code_type)
        finally:
            # Clean up temp file and directory
            try:
                if temp_file_path and temp_file_path.exists():
                    try:
                        temp_file_path.unlink()
                    except PermissionError:
                        # On Windows, files might be locked
                        logger.warning(f"Could not delete temp file {temp_file_path}")

                if temp_dir and temp_dir.exists():
                    try:
                        import shutil

                        shutil.rmtree(temp_dir, ignore_errors=True)
                    except Exception as e:
                        logger.warning(f"Could not delete temp directory: {e}")
            except Exception as e:
                logger.warning(f"Error during cleanup: {e}")

    def _create_temp_file(self, code: str, extension: str) -> Path:
        r"""Creates a temporary file with the given code and extension.

        Args:
            code (str): The code to write to the temporary file.
            extension (str): The file extension to use.

        Returns:
            Path: The path to the created temporary file.
        """
        try:
            # Create a temporary directory first to ensure we have write
            # permissions
            temp_dir = tempfile.mkdtemp()
            # Create file path with appropriate extension
            file_path = Path(temp_dir) / f"temp_code.{extension}"

            # Write code to file with appropriate encoding
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(code)

            return file_path
        except Exception as e:
            # Clean up temp directory if creation failed
            if "temp_dir" in locals():
                try:
                    import shutil

                    shutil.rmtree(temp_dir, ignore_errors=True)
                except Exception:
                    pass
            logger.error(f"Failed to create temporary file: {e}")
            raise

    def _check_code_type(self, code_type: str) -> str:
        if code_type not in self._CODE_TYPE_MAPPING:
            raise InterpreterError(
                f"Unsupported code type {code_type}. Currently "
                f"`{self.__class__.__name__}` only supports "
                f"{', '.join(self._CODE_EXTENSION_MAPPING.keys())}."
            )
        return self._CODE_TYPE_MAPPING[code_type]

    def supported_code_types(self) -> list[str]:
        r"""Provides supported code types by the interpreter."""
        return list(self._CODE_EXTENSION_MAPPING.keys())

    def update_action_space(self, action_space: dict[str, Any]) -> None:
        r"""Updates action space for *python* interpreter"""
        raise RuntimeError("SubprocessInterpreter doesn't support `action_space`.")

    def _is_command_available(self, command: str) -> bool:
        r"""Check if a command is available in the system PATH.

        Args:
            command (str): The command to check.

        Returns:
            bool: True if the command is available, False otherwise.
        """
        if os.name == "nt":  # Windows
            # On Windows, use where.exe to find the command
            try:
                with open(os.devnull, "w") as devnull:
                    subprocess.check_call(
                        ["where", command],
                        stdout=devnull,
                        stderr=devnull,
                        shell=False,
                    )
                return True
            except subprocess.CalledProcessError:
                return False
        else:  # Unix-like systems
            # On Unix-like systems, use which to find the command
            try:
                with open(os.devnull, "w") as devnull:
                    subprocess.check_call(
                        ["which", command],
                        stdout=devnull,
                        stderr=devnull,
                        shell=False,
                    )
                return True
            except subprocess.CalledProcessError:
                return False


class CodeExecutionToolkit(AsyncBaseToolkit):
    r"""A tookit for code execution.

    Args:
        sandbox (str): The environment type used to execute code.
        verbose (bool): Whether to print the output of the code execution.
            (default: :obj:`False`)
        unsafe_mode (bool):  If `True`, the interpreter runs the code
            by `eval()` without any security check. (default: :obj:`False`)
        import_white_list ( Optional[List[str]]): A list of allowed imports.
            (default: :obj:`None`)
        require_confirm (bool): Whether to require confirmation before executing code.
            (default: :obj:`False`)
    """

    def __init__(self, config: ToolkitConfig = None) -> None:
        super().__init__(config)
        # sandbox: Literal["subprocess"] = "subprocess",  # "internal_python", "jupyter", "docker", "e2b"
        self.verbose: bool = True
        self.unsafe_mode: bool = False
        self.import_white_list: list[str] | None = []
        self.require_confirm: bool = False
        self.timeout: float | None = None

        working_dir = self.config.config.get("working_dir")
        if working_dir:
            os.makedirs(working_dir, exist_ok=True)

        self.interpreter = SubprocessInterpreter(
            require_confirm=self.require_confirm,
            print_stdout=self.verbose,
            print_stderr=self.verbose,
            working_dir=working_dir,
        )

    async def execute_code(self, code: str) -> str:
        r"""Execute Python code in a secure sandbox environment.

        This tool provides a Python interpreter for executing code safely. Use it to run
        calculations, data processing, analysis, and other computational tasks.

        Args:
            code (str): Valid Python code to execute. To see output values, use print()
                statements explicitly. For image visualization, save images to the current
                directory (./) - they will be automatically displayed to the user.
                Package installation via pip is not permitted.

        Returns:
            str: Execution results including any printed output, return values, or error messages.

        Security Notes:
            - Only execute trusted code
            - Avoid potentially harmful operations
            - File system access is restricted to the working directory
        """
        output = self.interpreter.run(code, "python")
        # ruff: noqa: E501
        # content = f"Executed the code below:\n```py\n{code}\n```\n> Executed Results:\n{output}"
        content = f"> Executed Results:\n{output}"

        if self.verbose:
            print(content)
        return content

    def execute_code_file(self, file_path: str) -> str:
        r"""Execute a given code file.

        Args:
            file_path (str): The path to the code file to execute.

        """
        with open(file_path) as f:
            code = f.read()
        f.close()
        return self.execute_code(code)

    async def get_tools_map(self) -> dict[str, Callable]:
        return {"execute_code": self.execute_code}
